package com.soc.auth.ratelimit;

import com.google.common.util.concurrent.RateLimiter;
import com.soc.auth.utils.IPUtil;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/**
 * 限流功能的切面
 */
@Slf4j
@Aspect
@Component
public class RateLimitAspect {

    @Autowired
    private HttpServletRequest httpServletRequest;

    // 用来存放不同 API 的 RateLimiter (该map中value为RateLimiter实例)
    private final ConcurrentHashMap<String, RateLimiter> EXISTED_RATE_LIMITERS = new ConcurrentHashMap<>();

    // @within：用于匹配标注@RateLimit类内中的所有方法，@annotation：用于匹配当前执行方法持有指定注解的方法
    // 定义切入点，只要被自定义的@RateLimit修饰的方法或者被自定义的@RateLimit修饰的类下的所有方法就是切入点
    @Pointcut("@within(com.soc.auth.ratelimit.RateLimit) || @annotation(com.soc.auth.ratelimit.RateLimit)")
    public void rateLimit() {
    }

    @Around("rateLimit()") // 带有@RateLimit注解的方法会被进行增强
    public Object around(ProceedingJoinPoint point) throws Throwable {
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        // 通过 AnnotationUtils.findAnnotation 从方法上获取 @RateLimiter 注解
        RateLimit annotation = AnnotationUtils.findAnnotation(method, RateLimit.class);
        // 若方法上找不到， 从类上获取注解
        if (annotation == null) {
            Class<?> declaringClass = method.getDeclaringClass();
            annotation = declaringClass.getAnnotation(RateLimit.class);
        }
        // 获得最终的注解信息
        RateLimit finalAnnotation = annotation;

        /*
        * RateLimiter 使用令牌桶算法，有两种模式，一种是平滑突发限流，一种是平滑预热限流
        * 令牌桶算法则是一个存放固定容量令牌的桶，按照固定速率往桶里添加令牌。桶中存放的令牌数有最大上限，超出之后就被丢弃或者拒绝。当流量或者网络请求到达时，
        * 每个请求都要获取一个令牌，如果能够获取到，则直接处理，并且令牌桶删除一个令牌。如果获取不同，该请求就要被限流，要么直接丢弃，要么在缓冲区等待。
        */

        /*
        * computeIfAbsent方法会判断map集合是否存在以method.getName()为key的value，存在就从map集合中取出该value；如果不存在会自动调用mappingFunction
        * 来生成该key对应的value值，存到map中，然后取出刚生成的value值；如果mappingFunction返回的值为null或抛出异常，则不会有记录存入map。
        */
        // key为方法的类名+方法名+请求的ip地址，这样能以用户的ip地址进行限流
        String key = method.getDeclaringClass().getName() + "." + method.getName() + "." + IPUtil.getIpAddr(httpServletRequest);
        // 从ConcurrentHashMap中根据key获得RateLimiter对象
        RateLimiter rateLimiter = EXISTED_RATE_LIMITERS.computeIfAbsent(key, new Function<String, RateLimiter>() {
            @Override
            public RateLimiter apply(String s) {
                // 这里RateLimiter采用平滑突发限流
                return RateLimiter.create(finalAnnotation.qps());
            }
        });

        // rateLimiter.tryAcquire() 为非阻塞方式获取令牌，如果获取不到令牌超过移动时间，则返回false
        if (rateLimiter!=null && rateLimiter.tryAcquire(finalAnnotation.timeout(), finalAnnotation.timeUnit())) {
            // rateLimiter 不为空且必须获取到令牌后才执行被代理的方法
            return point.proceed();
        } else {
            throw new RuntimeException("请求太多，请稍后尝试。。。");
        }
    }
}
